b5d61b867c9d01c1902ab91df778967e98790c8a,h2o-core/src/main/java/hex/ModelBuilder.java,ModelBuilder,computeCrossValidation,#,287

Before Change


    final Integer N = nFoldWork();
    final Vec foldAssignment;
    if (_parms._fold_column != null) {
      foldAssignment = origTrainFrame.vec(_parms._fold_column);
    } else {
      final long seed = _parms.nFoldSeed();
      Log.info("Creating " + N + " cross-validation splits with random number seed: " + seed);

After Change


    // Step 1: build a per-row fold-assignment Vec, either from the user's fold column or
    // generated according to _parms._fold_assignment.
    // TODO: Implement better splitting algo (with Strata if response is categorical), e.g. http://www.lexjansen.com/scsug/2009/Liang_Xie2.pdf
    // NOTE(review): a Stratified option already exists in the switch below, so this TODO may be partially stale.
    final Integer N = nFoldWork(); // number of CV folds; boxed Integer, so it is auto-unboxed wherever used as an int
    final Vec foldAssignment;
    // NOTE(review): this lookup runs even when _fold_column is null; consider moving it inside the branch below.
    // presumably yields null in that case (confirm Frame.vec(null) semantics).
    final Vec foldCol = origTrainFrame.vec(_parms._fold_column);
    if (_parms._fold_column != null) {
      // User-supplied folds: convert to a categorical Vec. Its integer value (mod N) is used as the fold id below.
      foldAssignment = VecUtils.toCategoricalVec(foldCol);
    } else {
      final long seed = _parms.nFoldSeed();
      Log.info("Creating " + N + " cross-validation splits with random number seed: " + seed);
      // AUTO falls through to Random. Modulo ignores the seed. Stratified derives folds from response().
      switch( _parms._fold_assignment ) {
      case AUTO:
      case Random:     foldAssignment = ASTKFold.          kfoldColumn(    zTmp(),N,seed); break;
      case Modulo:     foldAssignment = ASTKFold.    moduloKfoldColumn(    zTmp(),N     ); break;
      case Stratified: foldAssignment = ASTKFold.stratifiedKFoldColumn(response(),N,seed); break;
      default:         throw H2O.unimpl();
      }
    }

    // Per-fold keys. modelKeys is filled in the loop below; predictionKeys is not filled in this visible portion.
    final Key[] modelKeys = new Key[N];
    final Key[] predictionKeys = new Key[N];

    // Step 2: Make 2*N weight vectors (each row gets either its original weight or 0) and store the CV train/validation frames
    final String origWeightsName = _parms._weights_column;
    final Vec[] weights = new Vec[2*N]; // weights[2*i] = train weights of fold i, weights[2*i+1] = validation weights of fold i
    // Use the user's weights column if given, otherwise a temporary constant-1.0 Vec (removed at the end).
    final Vec origWeight  = origWeightsName != null ? origTrainFrame.vec(origWeightsName) : origTrainFrame.anyVec().makeCon(1.0);
    final Frame[] cvTrain = new Frame[N];
    final Frame[] cvValid = new Frame[N];
    final String[] identifier = new String[N];
    final String weightName = "__internal_cv_weights__"; // reserved column name for the per-fold weights
    // NOTE(review): this check runs after foldAssignment/origWeight were created. If it throws, the cleanup at the end
    // of this section is skipped. Consider performing it before Step 1, unless an outer cleanup (not visible here) handles it.
    if (train().find(weightName) != -1) throw new H2OIllegalArgumentException("Frame cannot contain a Vec called '" + weightName + "'.");

    final Key<M> origDest = dest();
    for (int i=0; i<N; ++i) {
      // Make weights
      // presumably zTmp() allocates a fresh all-zero temporary Vec aligned with the training frame -- confirm
      weights[2*i]   = zTmp();
      weights[2*i+1] = zTmp();

      // Now update the weights in place
      final int whichFold = i; // effectively-final copy of the outer loop index for the anonymous MRTask
      new MRTask() {
        @Override
        public void map(Chunk chks[]) {
          // Chunk order matches the Vec[] passed to doAll() below.
          Chunk fold = chks[0];
          Chunk orig = chks[1];
          Chunk train = chks[2];
          Chunk valid = chks[3];
          // Note: this row index `i` shadows the outer fold loop's `i`; the fold number comes from whichFold.
          for (int i=0; i< orig._len; ++i) {
            // This int shadows the outer foldAssignment Vec. Java's % keeps the dividend's sign,
            // so the assert below relies on fold values being non-negative.
            // NOTE(review): confirm the fold column cannot contain NAs (at8 on a missing value is presumably an error).
            int foldAssignment = (int)fold.at8(i) % N;
            assert(foldAssignment >= 0 && foldAssignment <N);
            boolean holdout = foldAssignment == whichFold;
            double w = orig.atd(i);
            // Holdout rows go to validation only; all other rows go to training only.
            train.set(i, holdout ? 0 : w);
            valid.set(i, holdout ? w : 0);
          }
        }
      }.doAll(new Vec[]{foldAssignment, origWeight, weights[2*i], weights[2*i+1]});
      // A constant train or validation weight Vec means the split is degenerate (e.g. an empty fold).
      // NOTE(review): the message says "random" even when the user supplied a fold column. Throwing here also
      // skips the cleanup at the end of this section for the Vecs/Frames already created.
      if (weights[2*i].isConst() || weights[2*i+1].isConst()) {
        String msg = "Not enough data to create " + N + " random cross-validation splits. Either reduce nfolds, specify a larger dataset (or specify another random number seed, if applicable).";
        throw new H2OIllegalArgumentException(msg);
      }

      // CV model key for fold i: "<dest>_cv_<i+1>" (1-based).
      identifier[i] = origDest.toString() + "_cv_" + (i+1);
      modelKeys[i] = Key.make(identifier[i]);

      // Training/Validation share the same data, but will have exclusive weights
      // Each frame reuses origTrainFrame's Vecs. The user's weights column is dropped from the frame and
      // replaced by the internal per-fold weights column, then the frame is published to the DKV.
      cvTrain[i] = new Frame(Key.make(identifier[i]+"_"+_parms._train.toString()+"_train"), origTrainFrame.names(), origTrainFrame.vecs());
      if (origWeightsName!=null) cvTrain[i].remove(origWeightsName);
      cvTrain[i].add(weightName, weights[2*i]);
      DKV.put(cvTrain[i]);
      cvValid[i] = new Frame(Key.make(identifier[i]+"_"+_parms._train.toString()+"_valid"), origTrainFrame.names(), origTrainFrame.vecs());
      if (origWeightsName!=null) cvValid[i].remove(origWeightsName);
      cvValid[i].add(weightName, weights[2*i+1]);
      DKV.put(cvValid[i]);
    }

    // clean up memory (mostly small helper vectors and Frame headers)
    // Remove the fold-assignment Vec when it was generated here. The identity check guards against ever deleting the
    // user's own fold column (in case toCategoricalVec returns its input unchanged -- confirm).
    if (foldAssignment != foldCol || _parms._fold_column == null) foldAssignment.remove();
    // Remove the temporary constant-1.0 weights Vec; a user-supplied weights column is left alone.
    if (origWeightsName == null) origWeight.remove();

    // adapt main Job's progress bar to build N+1 models